# -*- coding: utf-8 -*-
"""
Created on Sat Apr  9 09:38:56 2022

@author: S1mple
"""
from PIL import Image
import torch
from torch.utils.data import Dataset
import random

class MyDataset(Dataset):
    """Custom image-classification dataset backed by a list of file paths.

    Each sample is a ``(image, label)`` pair, where the image is opened from
    ``images_path[i]`` and converted to RGB, and the label is
    ``images_class[i]``.

    Args:
        images_path: Paths of the image files, one per sample.
        images_class: Labels, aligned index-by-index with ``images_path``.
        transform: Optional callable applied to the PIL image before it is
            returned.

    Raises:
        ValueError: If ``images_path`` and ``images_class`` differ in length.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        # Fail early: a mismatch would otherwise surface as an IndexError deep
        # inside a DataLoader worker, or silently drop the extra labels.
        if len(images_path) != len(images_class):
            raise ValueError(
                "images_path and images_class must have the same length, "
                f"got {len(images_path)} and {len(images_class)}"
            )
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.images_path)

    def __getitem__(self, item):
        """Load sample ``item`` and return it as ``(image, label)``.

        The image is always delivered in RGB mode ('RGB' is colour, 'L' is
        greyscale), so non-RGB files are converted to three channels.
        """
        # Pillow opens files lazily; the context manager guarantees the file
        # handle is released instead of lingering until garbage collection.
        with Image.open(self.images_path[item]) as source:
            # convert() loads the pixel data and returns a new image that is
            # independent of the file, so it stays usable after the close.
            img = source.convert('RGB')
        label = self.images_class[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        """Merge a list of ``(image, label)`` samples into batched tensors.

        Images are stacked along a new leading dimension and labels are
        converted with ``torch.as_tensor``.
        NOTE(review): assumes every image is already a tensor of identical
        shape (i.e. ``transform`` produced tensors) -- confirm with callers.

        For the official default_collate implementation, see
        https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        """
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
